# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convert coco data to mindrecord format."""

import os

import numpy as np
from mindspore.mindrecord import FileWriter
from pycocotools.coco import COCO

from mindvision.engine.class_factory import ClassFactory, ModuleType

from mindvision.detection.datasets.coco import has_valid_annotation
from mindvision.detection.datasets.utils.classes import get_classes


def ann_to_mask(ann, height, width):
    """
    Decode the segmentation of one annotation into a binary mask.

    Args:
        ann (dict): Annotation holding a ``'segmentation'`` entry.
        height (int): Image height forwarded to ``frPyObjects``.
        width (int): Image width forwarded to ``frPyObjects``.

    Returns:
        The result of ``pycocotools.mask.decode`` for the annotation's RLE.
    """
    from pycocotools import mask as coco_mask

    segmentation = ann['segmentation']

    if isinstance(segmentation, list):
        # List-valued segmentation: encode every part, then merge the
        # parts into a single RLE before decoding.
        encoded_parts = coco_mask.frPyObjects(segmentation, height, width)
        return coco_mask.decode(coco_mask.merge(encoded_parts))

    if isinstance(segmentation['counts'], list):
        # Dict whose 'counts' is a list still has to be encoded first.
        return coco_mask.decode(coco_mask.frPyObjects(segmentation, height, width))

    # Anything else is handed to the decoder unchanged.
    return coco_mask.decode(segmentation)


@ClassFactory.register(ModuleType.DATASET)
class Coco2MindRecord:
    """
    Convert a COCO dataset to MindRecord files.

    Args:
        root (str): Directory that image file names from the annotation
            file are joined onto.
        ann_file (str): Path of the COCO annotation file passed to ``COCO``.
        mindrecord_dir (str): Directory the MindRecord files are written to.
        remove_images_without_annos (bool): Skip images for which
            ``has_valid_annotation`` returns a false value. Default: True.
        filter_crowd_anno (bool): Drop annotations whose ``iscrowd`` is 1.
            Default: True.
        with_mask (bool): Also store per-instance masks and their shape.
            Default: False.
    """

    COCO_CLASSES = list(get_classes(label="COCO"))
    COCO_CLASSES.insert(0, "background")

    def __init__(self, root, ann_file, mindrecord_dir,
                 remove_images_without_annos=True,
                 filter_crowd_anno=True, with_mask=False):
        """Constructor for Coco2MindRecord"""
        self.data_root = root
        self.ann_file = ann_file
        self.mindrecord_dir = mindrecord_dir
        self.with_mask = with_mask
        self.remove_images_without_annos = remove_images_without_annos
        self.filter_crowd_anno = filter_crowd_anno

    def create_coco_label(self):
        """
        Get image paths and annotations from COCO.

        Returns:
            tuple: ``(image_id, image_files, image_anno_dict)``; when
            ``with_mask`` is set, ``masks`` and ``masks_shape`` are appended.
            ``image_anno_dict`` maps an image path to rows of
            ``[x1, y1, x2, y2, class_index, iscrowd]``. ``masks`` maps an
            image path to the raw bytes of a boolean mask stack and
            ``masks_shape`` to the shape of that stack.
        """
        coco = COCO(self.ann_file)

        categories = coco.loadCats(coco.getCatIds())
        classes_dict = {cat["id"]: cat["name"] for cat in categories}

        # Classes need to train or test.
        train_cls = self.COCO_CLASSES
        if train_cls is not None:
            train_cls_dict = {name: idx for idx, name in enumerate(train_cls)}
        else:
            train_cls_dict = {cat["name"]: cat["id"] for cat in categories}

        image_files = []
        image_id = []
        image_anno_dict = {}
        masks = {}
        masks_shape = {}

        for img_id in coco.getImgIds():
            image_info = coco.loadImgs(img_id)[0]
            anno = coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=None))

            # filter empty annotations
            if self.remove_images_without_annos and not has_valid_annotation(anno):
                continue

            image_path = os.path.join(self.data_root, image_info["file_name"])
            image_height = image_info["height"]
            image_width = image_info["width"]
            annos = []
            instance_masks = []

            for label in anno:
                class_name = classes_dict[label["category_id"]]

                # filter crowd
                if self.filter_crowd_anno and label["iscrowd"] == 1:
                    continue

                # Dict lookup instead of scanning the class list each time.
                if class_name not in train_cls_dict:
                    continue

                # COCO boxes are [x, y, w, h]; store corner coordinates.
                x, y, box_w, box_h = label["bbox"][:4]
                if self.with_mask:
                    instance_masks.append(
                        ann_to_mask(label, image_height, image_width))
                annos.append([x, y, x + box_w, y + box_h,
                              train_cls_dict[class_name],
                              int(label["iscrowd"])])

            image_files.append(image_path)
            image_id.append(img_id)

            if annos:
                image_anno_dict[image_path] = np.array(annos)
                if self.with_mask:
                    # np.bool was removed in NumPy 1.24; np.bool_ is the
                    # scalar type available in every NumPy release.
                    stacked = np.stack(instance_masks, axis=0).astype(np.bool_)
                    masks[image_path] = stacked.tobytes()
                    masks_shape[image_path] = np.array(stacked.shape, dtype=np.int32)
            else:
                # Placeholder row for images left without any usable box.
                image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
                if self.with_mask:
                    empty_shape = [1, image_height, image_width]
                    masks[image_path] = np.zeros(empty_shape, dtype=np.bool_).tobytes()
                    masks_shape[image_path] = np.array(empty_shape, dtype=np.int32)

        if self.with_mask:
            return image_id, image_files, image_anno_dict, masks, masks_shape
        return image_id, image_files, image_anno_dict

    def data_to_mindrecord_byte_image(self, prefix="coco.mindrecord", file_num=8):
        """
        Create MindRecord file.

        Args:
            prefix (str): File name joined onto ``mindrecord_dir`` and passed
                to ``FileWriter``. Default: "coco.mindrecord".
            file_num (int): Second argument passed to ``FileWriter``.
                Default: 8.
        """
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(self.mindrecord_dir, exist_ok=True)

        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        writer = FileWriter(mindrecord_path, file_num)

        schema = {
            "img_id": {"type": "int32", "shape": [1]},
            "image": {"type": "bytes"},
            "annotation": {"type": "int32", "shape": [-1, 6]},
        }
        masks = {}
        masks_shape = {}
        if self.with_mask:
            image_id, image_files, image_anno_dict, masks, masks_shape = self.create_coco_label()
            schema["mask"] = {"type": "bytes"}
            schema["mask_shape"] = {"type": "int32", "shape": [-1]}
            writer.add_schema(schema, "maskrcnn_json")
        else:
            image_id, image_files, image_anno_dict = self.create_coco_label()
            writer.add_schema(schema, "fasterrcnn_json")

        for img_id, image_name in zip(image_id, image_files):
            with open(image_name, 'rb') as f:
                img = f.read()
            row = {
                "img_id": np.array([img_id], dtype=np.int32),
                "image": img,
                "annotation": np.array(image_anno_dict[image_name], dtype=np.int32),
            }
            if self.with_mask:
                row["mask"] = masks[image_name]
                row["mask_shape"] = masks_shape[image_name]
            writer.write_raw_data([row])
        writer.commit()

    def __call__(self, prefix='coco.mindrecord'):
        """
        Write mindrecord file unless ``<mindrecord_dir>/<prefix>0`` exists.

        Args:
            prefix (str): Output file name prefix. Default: "coco.mindrecord".
        """
        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        if os.path.exists(mindrecord_path + "0"):
            return
        # Forward the prefix so the files written match the path checked above.
        self.data_to_mindrecord_byte_image(prefix)
